{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d8544be7",
   "metadata": {},
   "source": [
    "# Training a CNN to Predict Battery Cycle Life\n",
    "*Noah Paulson, Argonne National Laboratory*\n",
    "\n",
    "\n",
    "In this notebook we will collect data, visualize the machine learning features, define a convolutional neural network architecture, train it, and evaluate the results.\n",
    "\n",
    "### Import packages, define data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ca0704",
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import time\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from urllib import request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b53dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(logy=True, output=False):\n",
    "    \"\"\"Download the battery dataset and split it into train/val/test sets.\n",
    "\n",
    "    For every cell, the feature array is each row after the first of its\n",
    "    'discharging voltage vs capacity' table minus the first row. Cells\n",
    "    whose name contains 'Batch3' form the test set; 20% of the remaining\n",
    "    cells are picked at random (using the global ``np.random`` state) for\n",
    "    validation and the rest are used for training.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logy : bool\n",
    "        If True, the targets are log10 of 'cycles to 90pct capacity'.\n",
    "    output : bool\n",
    "        If True, print the dataset sizes and sample-weight diagnostics.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X, y, n : dict\n",
    "        Features (with a trailing singleton axis), targets, and number of\n",
    "        cells, each keyed by 'trn', 'val', and 'tst'.\n",
    "    iD : dict\n",
    "        Boolean masks over all cells, keyed by 'trn', 'val', and 'tst'.\n",
    "    weights : numpy.ndarray\n",
    "        One weight per training cell: mean histogram density divided by\n",
    "        the density of the target bin the cell falls into.\n",
    "    cyceval, veval : numpy.ndarray\n",
    "        Contents of 'evaluated cycles' and 'evaluated voltages', copied\n",
    "        into memory so they stay usable after the HDF5 file is closed.\n",
    "    \"\"\"\n",
    "\n",
    "    remote_url = 'https://drive.google.com/uc?export=download&id=1orXDHcq2TzRLhGkHyWVtsuGn7_10pSAn'\n",
    "    fname = 'energy_capacity.h5'\n",
    "    request.urlretrieve(remote_url, fname)\n",
    "\n",
    "    Xin = []  # per-cell feature arrays\n",
    "    cycfailL = []  # per-cell failure cycles\n",
    "    batchL = []  # per-cell batch numbers (0 when no batch tag matches)\n",
    "\n",
    "    # the context manager closes the file even if a read fails, so\n",
    "    # everything needed later is copied into memory inside this block\n",
    "    with h5py.File(fname, 'r') as f:\n",
    "\n",
    "        # evaluated cycles and voltages\n",
    "        cyceval = f['evaluated cycles'][()]\n",
    "        veval = f['evaluated voltages'][()]\n",
    "\n",
    "        cells = f['cell data']\n",
    "        for aname in cells.keys():\n",
    "\n",
    "            g = cells[aname]\n",
    "\n",
    "            cycfailL.append(g['cycles to 90pct capacity'][0])\n",
    "\n",
    "            VvC = g['discharging voltage vs capacity']\n",
    "\n",
    "            # our features are based on the difference in capacity between cycles\n",
    "            Xin.append(VvC[1:, :] - VvC[0, :])\n",
    "\n",
    "            if 'Batch1' in aname:\n",
    "                batchL.append(1)\n",
    "            elif 'Batch2' in aname:\n",
    "                batchL.append(2)\n",
    "            elif 'Batch3' in aname:\n",
    "                batchL.append(3)\n",
    "            else:\n",
    "                batchL.append(0)\n",
    "\n",
    "    ncell = len(cycfailL)\n",
    "    cycfailV = np.array(cycfailL, dtype=float)  # failure cycle array\n",
    "    batchV = np.array(batchL)  # batch array\n",
    "\n",
    "    Xin = np.stack(Xin)\n",
    "    nfeat = Xin[0].size\n",
    "    if output:\n",
    "        print('number of cells:', ncell)\n",
    "        print('number of features:', nfeat)\n",
    "        print(Xin.shape)\n",
    "\n",
    "    # create the train, validation, and test sets\n",
    "\n",
    "    dsetL = ['trn', 'val', 'tst']\n",
    "\n",
    "    iD = {}\n",
    "\n",
    "    # batch 3 is reserved for test\n",
    "    iD['tst'] = batchV == 3\n",
    "    ntst = np.sum(iD['tst'])\n",
    "    nrst = ncell - ntst\n",
    "\n",
    "    # validation and training sets are pulled from batches 1 and 2\n",
    "    nval = int(0.2*nrst)\n",
    "\n",
    "    # test cells get a score of 0 so the largest scores belong to the\n",
    "    # remaining cells\n",
    "    irnd = np.random.rand(ncell)\n",
    "    irnd[iD['tst']] = 0\n",
    "    # slice from an explicit start index: [-nval:] would select every\n",
    "    # cell when nval is 0\n",
    "    top_nval = np.argsort(irnd)[ncell-nval:]\n",
    "\n",
    "    iD['val'] = np.zeros((ncell,), dtype='bool')\n",
    "    iD['val'][top_nval] = True\n",
    "    iD['trn'] = ~iD['val'] & ~iD['tst']\n",
    "\n",
    "    X = {}  # features dictionary\n",
    "    y = {}  # failure cycles dictionary\n",
    "    n = {}  # number of cells dictionary\n",
    "    for dset in dsetL:\n",
    "        X[dset] = Xin[iD[dset], ..., None]\n",
    "        if logy:\n",
    "            y[dset] = np.log10(cycfailV[iD[dset]])\n",
    "        else:\n",
    "            y[dset] = cycfailV[iD[dset]]\n",
    "        n[dset] = np.sum(iD[dset])\n",
    "\n",
    "    # define weights for each cell based on the relative number of\n",
    "    # cells available for that cycle-life bin. For example, there\n",
    "    # are many low cycle-life cells so those have low weights, and\n",
    "    # there are relatively few high cycle-life cells so those have\n",
    "    # higher weights.\n",
    "    bins = 4\n",
    "    hist, bin_edges = np.histogram(y['trn'], bins=bins, density=True)\n",
    "\n",
    "    # an empty bin has zero density and would trigger a divide warning;\n",
    "    # its infinite weight is never assigned because no cell falls in it\n",
    "    with np.errstate(divide='ignore'):\n",
    "        invhist = np.mean(hist)/hist\n",
    "\n",
    "    weights = np.ones(y['trn'].shape)\n",
    "    for ii in range(bins):\n",
    "        lwr = bin_edges[ii]\n",
    "        upr = bin_edges[ii+1]\n",
    "        # only the first bin includes its lower edge\n",
    "        if ii == 0:\n",
    "            sel = (lwr <= y['trn'])*(y['trn'] <= upr)\n",
    "        else:\n",
    "            sel = (lwr < y['trn'])*(y['trn'] <= upr)\n",
    "        weights[sel] = invhist[ii]\n",
    "\n",
    "    if output:\n",
    "        print('bin_edges:', bin_edges)\n",
    "        print('original training densities:', hist)\n",
    "        print('category weights:', invhist)\n",
    "        print('y_trn weights:', weights)\n",
    "        histw, bew = np.histogram(\n",
    "            y['trn'], bins=bins, density=True, weights=weights)\n",
    "        print('new densities:', histw)\n",
    "\n",
    "    return X, y, n, iD, weights, cyceval, veval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1e3f0e1",
   "metadata": {},
   "source": [
    "### Define key variables for the fitting and extract the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6063130",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 200  # maximum number of training epochs\n",
    "logy = True  # apply a log transform to the cycle lives as a normalization scheme\n",
    "output = True  # print dataset diagnostics while loading\n",
    "batch_size = 32\n",
    "dsetL = ['trn', 'val', 'tst']\n",
    "\n",
    "sns.set()\n",
    "\n",
    "# seed NumPy (used for the random validation split) and TensorFlow\n",
    "# (used for weight initialization) so a fresh run is repeatable\n",
    "np.random.seed(1)\n",
    "tf.random.set_seed(1)\n",
    "\n",
    "X, y, n, iD, weights, cyceval, veval = get_data(logy, output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd81b88",
   "metadata": {},
   "source": [
    "### Visualize features for machine learning\n",
    "\n",
    "In this work, we define the features as the difference in capacity at different voltages during discharge between cycle X and cycle 10, where cycle X is taken from the set [20, 30, 40, 50, 60, 70, 80, 90, 100]. There are a total of 180 features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64110f2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# feature image of the first training cell, without the trailing axis\n",
    "sample = X['trn'][0, ..., 0]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "img = ax.imshow(sample)\n",
    "plt.colorbar(img, ax=ax, label='Q(Voltage, Cycle_X) - \\nQ(Voltage, Cycle_10)   (Ah)')\n",
    "\n",
    "# label every second column with its voltage\n",
    "ax.set_xticks(np.arange(sample.shape[1])[::2])\n",
    "ax.set_xticklabels(np.round(veval[::2], 1))\n",
    "\n",
    "# label every row with its evaluated cycle; the first evaluated cycle is\n",
    "# skipped because it is the reference subtracted from the other rows\n",
    "ax.set_yticks(np.arange(sample.shape[0]))\n",
    "ax.set_yticklabels(cyceval[()][1:].astype(str))\n",
    "\n",
    "ax.grid(False)\n",
    "ax.set_xlabel('Voltage')\n",
    "ax.set_ylabel('Cycle')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2acd07f9",
   "metadata": {},
   "source": [
    "### Define the model\n",
    "\n",
    "We define the convolutional network with two convolutional layers and one dense layer. Feel free to explore the number of filters/neurons, the shape of the filters, and the number of different types of layers. It is also interesting to see the effect of learning rate and loss metric on fitting and performance. We print a summary of the model that shows the types of layers, the shape of their outputs, and the number of parameters associated with the layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33169007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the input is one single-channel feature image per cell\n",
    "inputs = keras.Input(shape=(X['trn'].shape[1], X['trn'].shape[2], 1), name='farr')\n",
    "x = keras.layers.Conv2D(16, (3, 3), activation='relu')(inputs)\n",
    "# the second convolution must consume the output of the first one; wiring it\n",
    "# to `inputs` would leave the 16-filter layer disconnected from the model\n",
    "x = keras.layers.Conv2D(8, (3, 3), activation='relu')(x)\n",
    "x = keras.layers.Flatten()(x)\n",
    "# optional layers to experiment with:\n",
    "# x = keras.layers.Dropout(0.2)(x)\n",
    "# x = keras.layers.BatchNormalization()(x)\n",
    "x = keras.layers.Dense(16, activation='relu')(x)\n",
    "# a single linear output unit for the regression target\n",
    "outputs = keras.layers.Dense(1)(x)\n",
    "model = keras.Model(inputs=inputs, outputs=outputs, name='functional_1')\n",
    "\n",
    "model.summary()\n",
    "\n",
    "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01, amsgrad=False),\n",
    "              loss='MSE',\n",
    "              metrics=[])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04ed5136",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321f8637",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stop once the validation loss has not improved for 50 epochs and\n",
    "# restore the weights of the best epoch\n",
    "early_stop = keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss', patience=50, restore_best_weights=True)\n",
    "\n",
    "t_start = time.time()\n",
    "\n",
    "history = model.fit(\n",
    "    X['trn'], y['trn'],\n",
    "    sample_weight=weights,\n",
    "    validation_data=(X['val'], y['val']),\n",
    "    epochs=epochs,\n",
    "    batch_size=batch_size,\n",
    "    callbacks=[early_stop])\n",
    "\n",
    "elapsed = time.time() - t_start\n",
    "print('fit time:', np.round(elapsed))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa7b826",
   "metadata": {},
   "source": [
    "### Make predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b32e75e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one prediction array per data set, with singleton axes squeezed out\n",
    "ypred = {dset: np.squeeze(model.predict(X[dset])) for dset in dsetL}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "142d14a4",
   "metadata": {},
   "source": [
    "### Plot loss versus epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee1956d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure('metrics_vs_epoch')\n",
    "\n",
    "# epochs are numbered from 1 on the horizontal axis\n",
    "epoch_axis = range(1, len(history.history['loss']) + 1)\n",
    "\n",
    "curves = [('loss', 'b-', 'training'),\n",
    "          ('val_loss', 'r-', 'validation')]\n",
    "for key, fmt, lbl in curves:\n",
    "    plt.semilogy(epoch_axis, history.history[key],\n",
    "                 fmt, alpha=0.5, label=lbl)\n",
    "\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.legend()\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76ca4dc1",
   "metadata": {},
   "source": [
    "### Parity Plot\n",
    "\n",
    "The parity plot shows the experimental cycle lives versus the predicted cycle lives. Accurate predictions are close to the 45 degree centerline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "762645e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = sns.color_palette('cubehelix', 5)[1:-1]\n",
    "markers = ['o', 's', 'D']\n",
    "labels = ['train', 'validate', 'test']\n",
    "\n",
    "plt.figure(num='parity-plot', figsize=(5.5, 5))\n",
    "\n",
    "extremes = []  # smallest and largest values seen, for the axis limits\n",
    "for dset, c, m, lbl in zip(dsetL, colors, markers, labels):\n",
    "    # undo the log transform so both axes are in cycles\n",
    "    if logy:\n",
    "        yP, yT = 10**ypred[dset], 10**y[dset]\n",
    "    else:\n",
    "        yP, yT = ypred[dset], y[dset]\n",
    "\n",
    "    extremes += [np.min(yP), np.min(yT), np.max(yP), np.max(yT)]\n",
    "\n",
    "    plt.plot(yT, yP, color=c, marker=m,\n",
    "             ls='', ms=5, alpha=0.6, label=lbl)\n",
    "\n",
    "# pad the shared axis range by 5% on each side\n",
    "ymin, ymax = np.min(extremes), np.max(extremes)\n",
    "pad = .05*(ymax - ymin)\n",
    "lwr, upr = ymin - pad, ymax + pad\n",
    "\n",
    "# centerline on which prediction equals experiment\n",
    "plt.plot([lwr, upr], [lwr, upr], 'k-')\n",
    "plt.xlim([lwr, upr])\n",
    "plt.ylim([lwr, upr])\n",
    "plt.xlabel('experimental cycles to failure')\n",
    "plt.ylabel('predicted cycles to failure')\n",
    "plt.legend()\n",
    "plt.gca().set_aspect('equal', adjustable='box')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b2113db",
   "metadata": {},
   "source": [
    "### Error Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f4ae66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate errors and load into dictionary\n",
    "metricD = {}\n",
    "for dset in dsetL:\n",
    "\n",
    "    # squared error in the units the model was trained on\n",
    "    metricD[dset + '_loss'] = np.mean((ypred[dset]-y[dset])**2)\n",
    "\n",
    "    # percent errors are evaluated on cycle lives, so the log transform\n",
    "    # is undone first when it was applied\n",
    "    if logy:\n",
    "        abs_err = np.abs(10**ypred[dset]-10**y[dset])\n",
    "        rel_scale = 10**y[dset]\n",
    "        max_scale = 10**y[dset].max()\n",
    "    else:\n",
    "        abs_err = np.abs(ypred[dset]-y[dset])\n",
    "        rel_scale = y[dset]\n",
    "        max_scale = y[dset].max()\n",
    "\n",
    "    # error relative to each true value, and relative to the largest one\n",
    "    metricD[dset + '_mapre'] = 100*np.mean(abs_err/rel_scale)\n",
    "    metricD[dset + '_mape'] = 100*np.mean(abs_err/max_scale)\n",
    "\n",
    "# print results\n",
    "report = [\n",
    "    ('_loss', 'mean squared error:', 4),\n",
    "    ('_mapre', 'mean absolute percent relative error:', 1),\n",
    "    ('_mape', 'mean absolute percent error:', 1),\n",
    "]\n",
    "for dset in dsetL:\n",
    "    print('\\n')\n",
    "    for suffix, caption, digits in report:\n",
    "        print(dset, caption, np.round(metricD[dset + suffix], digits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "917ae945",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
